#############################################################
# Author: Mike Burnham, mlb6496@psu.edu
# Python: 3.11.5
# OS: Windows 10
#
# Notes: This script reproduces tables 3, 4, and 5 and appends them to tables.log.
##############################################################

import pandas as pd
from tabulate import tabulate
from sklearn.metrics import accuracy_score as acc, matthews_corrcoef as mcc, f1_score as f1

import logging

# Send all log records (DEBUG and above) to tables.log, appending to any
# existing file; each record is prefixed with a timestamp and its level name.
logging.basicConfig(
    filename="tables.log",
    filemode="a+",
    level=logging.DEBUG,
    format="%(asctime)-15s %(levelname)-8s %(message)s",
)

###########
## Table 3
###########
# Load the supervised-classifier predictions; the 'labels' column is passed
# to every metric as the true labels.
supervised = pd.read_csv('./trump_twitter_supervised.csv')
true = supervised['labels']

# Each prediction column is scored with MCC, F1 and accuracy (in that order),
# and every score is rounded to two decimals.
metrics = (mcc, f1, acc)
prediction_columns = (
    'logistic',   # Logistic Regression
    'forest',     # Random Forest
    'svm',        # SVM
    'roberta',    # RoBERTa
    'bertweet',   # BERTweet
    'polibert',   # PoliBERTweet
    'deberta',    # DeBERTa
)
scores = {
    column: [round(metric(true, supervised[column]), 2) for metric in metrics]
    for column in prediction_columns
}
# NOTE(review): the 'deberta' column is scored like the others but has no row
# in the table below -- confirm whether that omission is intentional.

# Sweep times and hardware are fixed strings, not values computed here.
headers = ["", "Model", "MCC", "F1", "Accuracy", "Sweep Time", "Hardware"]
data = [
    ["Bag-of-Words Classifiers", "Logistic Regression", *scores['logistic'], "1.53s", "CPU"],
    ["", "Random Forest", *scores['forest'], "2 min. 2s", "CPU"],
    ["", "SVM", *scores['svm'], "37.8s", "CPU"],
    ["Language Models", "RoBERTa", *scores['roberta'], "46 min", "GPU"],
    ["", "BERTweet", *scores['bertweet'], "44 min", "GPU"],
    ["", "PoliBERTweet", *scores['polibert'], "42 min", "GPU"],
]

# Group label centred, model name left-aligned, the remaining five columns
# right-aligned.
table = tabulate(
    data,
    headers=headers,
    colalign=("center", "left") + ("right",) * 5,
)

logging.info("Table 3\n" + str(table))

###########
## Table 4
###########
# Load the NLI classifier predictions; the 'adjudicated_sup' column is passed
# to every metric as the true labels.
nli = pd.read_csv('./trump_test_nli.csv')
true = nli['adjudicated_sup']

# One entry per table row: (model label, prediction column, GPU inference
# time, CPU inference time). The timings are fixed strings, not measured here.
model_specs = [
    ("DeBERTaV3 Small", 'gS', "3.06s", "2 min. 37s"),
    ("DeBERTaV3 Base", 'gB', "4.98s", "5 min. 8s"),
    ("DeBERTaV3 Large", 'gL', "13.5s", "16 min. 35s"),
]

headers = ["Model", "MCC", "F1", "Accuracy", "Inference Time (GPU)", "Inference Time (CPU)"]
data = []
for model, column, gpu_time, cpu_time in model_specs:
    predicted = nli[column]
    # MCC, F1 and accuracy, each rounded to two decimals.
    data.append([
        model,
        round(mcc(true, predicted), 2),
        round(f1(true, predicted), 2),
        round(acc(true, predicted), 2),
        gpu_time,
        cpu_time,
    ])

# Model name left-aligned, the remaining five columns right-aligned.
table = tabulate(data, headers=headers, colalign=("left",) + ("right",) * 5)

logging.info("Table 4\n" + str(table))

###########
## Table 5
###########
# Load the in-context (prompted) model predictions; the 'adjudicated_sup'
# column is passed to every metric as the true labels.
context = pd.read_csv('./trump_test_in_context.csv')
true = context['adjudicated_sup']

# Rows grouped by prompt style. Each entry is (model label, prediction column,
# GPU inference time, cost); the time and cost values are fixed strings, with
# "-" used where the table has no value for that cell.
prompt_groups = [
    ("Standard", [
        ("GPT-4", 'gpt4_nobias', "-", "$12.52"),
        ("GPT-3.5 Turbo", 'gpt3_5_nobias', "-", "$0.68"),
        ("Mistral 7B", 'mistral_nobias', "2 min. 32s", "-"),
    ]),
    ("Chain of Thought", [
        ("GPT-4", 'gpt4_cot', "-", "$38.62"),
        ("GPT-3.5 Turbo", 'gpt3_5_cot', "-", "$1.26"),
        ("Mistral 7B", 'mistral_cot', "1 hr. 50 min.", "-"),
    ]),
    ("Logit Bias", [
        ("GPT-4", 'gpt4_bias', "-", "$12.52"),
        ("GPT-3.5 Turbo", 'gpt3_5_bias', "-", "$0.68"),
        ("Mistral 7B", 'mistral_bias', "2 min. 33s", "-"),
    ]),
]

headers = ["Prompt", "Model", "MCC", "F1", "Accuracy", "Inference Time (GPU)", "Cost"]
data = []
for prompt, entries in prompt_groups:
    for position, (model, column, gpu_time, cost) in enumerate(entries):
        predicted = context[column]
        data.append([
            # The prompt label appears only on the first row of its group.
            prompt if position == 0 else "",
            model,
            # MCC, F1 and accuracy, each rounded to two decimals.
            round(mcc(true, predicted), 2),
            round(f1(true, predicted), 2),
            round(acc(true, predicted), 2),
            gpu_time,
            cost,
        ])

# Prompt label centred, model name left-aligned, the remaining five columns
# right-aligned.
table = tabulate(
    data,
    headers=headers,
    colalign=("center", "left") + ("right",) * 5,
)

logging.info("Table 5\n" + str(table))